/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
 */
#ifndef MODEL_H
#define MODEL_H
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
// Gives the annotated symbol "default" ELF visibility (GCC/Clang attribute),
// i.e. exports it from the shared library even under -fvisibility=hidden.
#define SCANN_API_PUBLIC __attribute__((visibility("default")))

namespace research_scann {

// One feature slot fed to IadpModel. Exactly one member is meaningful at a
// time; the union itself does not record which one (no tag member is stored).
union Entry {
    // Names the member an Entry is meant to carry; kept alongside the union
    // for callers that track the active member separately.
    enum class Type { MISSING, FVALUE, QVALUE };

    // NOTE: declaration order matters -- `missing` is the first member, so
    // `Entry{}` zero-initialises it and `Entry{x}` initialises it.
    int missing;    // presumably used for Type::MISSING slots
    double fvalue;  // floating-point payload (Type::FVALUE)
    int qvalue;     // integer payload (Type::QVALUE)
};

/**
 * Abstract interface of the adaptive ("adp") model used by ScaNN search.
 *
 * This header declares only pure virtual functions; any concrete
 * implementation lives elsewhere.
 * NOTE(review): the per-member descriptions below are inferred from names and
 * signatures only -- confirm against the implementation.
 *
 * The class is exported (SCANN_API_PUBLIC) and the declaration order of its
 * virtual functions fixes its vtable layout: do not reorder or insert virtual
 * members without rebuilding every implementation.
 */
class SCANN_API_PUBLIC IadpModel {
public:
    // Operating mode. Presumably DISABLE = model unused, COLLECT = gather
    // training samples, INFERENCE = predict with a trained model (TODO confirm).
    enum class AMODE { DISABLE, COLLECT, INFERENCE };
    // Set / query the current operating mode.
    virtual void SetMode(AMODE m) = 0;
    virtual AMODE GetMode() = 0;
    // Store / retrieve a vector of float "probe info" values. The setter takes
    // the vector by value, so an implementation may move it into its state.
    virtual void SetProbeInfo(std::vector<float> pi) = 0;
    virtual std::vector<float> GetProbeInfo() = 0;
    // Store / retrieve an integer "train 75p" value (presumably a 75th-
    // percentile statistic gathered during training -- confirm).
    virtual void SetTrain75p(int t75) = 0;
    virtual int GetTrain75p() = 0;
    // Configure search parameters (threshold `thd`, `prefine`, `pmax`) and,
    // per the name, possibly the mode. Returns bool -- presumably success.
    virtual bool SetSearchParamsAndMode(float thd, int prefine, int pmax) = 0;
    // Run one prediction over `data` (presumably filled by
    // ADP_COLLECT_DATA_PRED). The result is written to the out-parameter
    // `pred_result`; the bool return presumably signals success.
    virtual bool PredictOne(std::vector<Entry> &data, int &pred_result) = 0;
    // Train the model from collected samples. Both vectors are taken by
    // non-const reference, so an implementation may modify them.
    // NOTE(review): meaning of `epb` and `n_leaves` is not visible here.
    virtual void trainModel(std::vector<int64_t> &epb, std::vector<std::vector<Entry>> &data, int n_leaves) = 0;
    // Load / save model state from / to `libpath`; the bool return presumably
    // signals success.
    virtual bool LoadLibrary(std::string libpath) = 0;
    virtual bool SaveLibrary(std::string libpath) = 0;
    // Returns an identifier string (format not specified by this header).
    virtual std::string CreateUuid() = 0;
    // Returns a path string -- presumably the implementation's working directory.
    virtual std::string GetWorkPath() = 0;
    // Virtual so implementations can be destroyed through an IadpModel pointer
    // (e.g. via std::unique_ptr<IadpModel>).
    virtual ~IadpModel() = default;
};
// Global model instance; declared here and defined in a translation unit
// outside this header. NOTE(review): presumably populated from
// adpModelFactory() and may be null when the adaptive model is not in use --
// check before dereferencing.
SCANN_API_PUBLIC extern std::unique_ptr<IadpModel> pAdaptiveModel;
// Returns an IadpModel implementation; ownership passes to the caller through
// the std::unique_ptr.
SCANN_API_PUBLIC std::unique_ptr<IadpModel> adpModelFactory();

// Number of search results packed per prediction by ADP_COLLECT_DATA_PRED.
constexpr int ADP_GROUP_NUM{8};
// Entry slots written per search result by ADP_COLLECT_DATA (qvalue, fvalue).
constexpr int ADP_CATEGORY_NUM{2};

// Packs the first `g_num` search results of `data` into `entryList`, using
// ADP_CATEGORY_NUM consecutive Entry slots per result:
//   entryList[ADP_CATEGORY_NUM * j    ].qvalue = data[j].first
//   entryList[ADP_CATEGORY_NUM * j + 1].fvalue = data[j].second
//
//   entryList: std::vector<Entry>; resized to exactly g_num * ADP_CATEGORY_NUM.
//   data:      indexable sequence (presumably std::vector<KMeansTreeSearchResult>)
//              whose elements expose `.first` and `.second`; must hold at least
//              g_num elements (not checked).
//   g_num:     number of results to pack; must be non-negative.
//
// Each argument is evaluated exactly once. Namespace constants are fully
// qualified because a macro expands at the call site, which may be outside
// namespace research_scann. Internal locals use an `adp...`/trailing-underscore
// naming scheme so they cannot shadow identifiers in the caller's arguments
// (e.g. a caller-side `j` inside `data`).
#define ADP_COLLECT_DATA(entryList, data, g_num)                                       \
    do {                                                                               \
        auto &adpEntries_ = (entryList);                                               \
        const auto &adpSource_ = (data);                                               \
        const std::size_t adpCount_ = static_cast<std::size_t>(g_num);                 \
        const std::size_t adpStride_ =                                                 \
            static_cast<std::size_t>(::research_scann::ADP_CATEGORY_NUM);              \
        adpEntries_.resize(adpCount_ * adpStride_);                                    \
        for (std::size_t adpIdx_ = 0; adpIdx_ < adpCount_; ++adpIdx_) {                \
            adpEntries_[adpStride_ * adpIdx_].qvalue = adpSource_[adpIdx_].first;      \
            adpEntries_[adpStride_ * adpIdx_ + 1].fvalue = adpSource_[adpIdx_].second; \
        }                                                                              \
    } while (0)

// Packs the first ADP_GROUP_NUM results of `data` into `entryList` (see
// ADP_COLLECT_DATA); presumably the fixed-size input layout for PredictOne.
#define ADP_COLLECT_DATA_PRED(entryList, data) \
    ADP_COLLECT_DATA(entryList, data, ::research_scann::ADP_GROUP_NUM)

}  // namespace research_scann
#endif